#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# @Author: renjin@bit.edu.cn
# @Date  : 2024-10-30


"""
[Node name]:
    YOLOv11DetJetsonTensorRTNode
[Dependency installation]:
    pip install spirems
    pip install pycuda -i https://pypi.tuna.tsinghua.edu.cn/simple
    cd <path-to-SpireCV>/spirecv/algorithm/common_det/YOLOv11DetJetsonTensorRTLib
    mkdir build
    cd build
    cmake ..
    make
[Subscribed types]:
    sensor_msgs::CompressedImage (input image)
[Published types]:
    spirecv_msgs::2DTargets (detection results)
    sensor_msgs::CompressedImage (visualization results; requires a visualization tool)
    std_msgs::Boolean (if the upstream node feeds a dataset, this is published after each detection so the upstream node keeps feeding)
[Constructor parameters]:
    parameter_file (str): global parameter file
    sms_shutdown (bool): whether to honor the global shutdown signal; set to False for long-running background execution
    specified_input_topic (str): explicitly specified input topic
    specified_output_topic (str): explicitly specified output topic
    realtime_det (bool): whether to run as a realtime detector; True lowers latency but may drop frames
[Node parameters]:
    confidence (float): target score threshold
    nms_thresh (float): NMS postprocessing parameter
    dataset_name (str): dataset name
    pt_model (str): name of the model to load
[Notes]:
    None
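[Usage example]:
    A minimal launch sketch (assuming this file is saved as
    yolov11_det_jetson_tensorrt_node.py and default_params.json exists
    under <spirecv-pro>/params/spirecv2):
    python3 yolov11_det_jetson_tensorrt_node.py --config default_params.json -j live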
"""

import argparse
import ctypes
import os
import platform
import threading
import time
from queue import Queue
from typing import Union

import cv2
import numpy as np
import pycuda.autoinit  # noqa: F401  (initializes the CUDA driver)
import pycuda.driver as cuda
import tensorrt as trt

from spirems import Subscriber, Publisher, cvimg2sms, sms2cvimg, def_msg, QoS, BaseNode, get_extra_args
from spirems.mod_helper import download_model


POSE_NUM = 17 * 3
DET_NUM = 6
SEG_NUM = 32
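# Per-detection record layout in the engine output (assuming a tensorrtx-style
# YOLO plugin): DET_NUM box values (cx, cy, w, h, conf, cls_id), SEG_NUM mask
# coefficients, and POSE_NUM keypoint values (17 keypoints x 3),
# i.e. 6 + 32 + 51 = 89 floats per detection.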


def trans_det_results(result_boxes, result_scores, result_classid, w, h, categories, roi=None):
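    """
    Pack raw detection arrays into a spirecv_msgs::2DTargets message.
    Boxes arrive as [x1, y1, x2, y2] and are converted to COCO-style
    [x, y, w, h]; if an ROI was used for inference, coordinates are
    shifted back into the full-image frame.
    """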
    sms_results = def_msg('spirecv_msgs::2DTargets')

    sms_results["file_name"] = ""
    sms_results["height"] = int(h)
    sms_results["width"] = int(w)
    sms_results["targets"] = []
    if roi is not None:
        sms_results["rois"] = [[roi[0], roi[1], roi[2] - roi[0], roi[3] - roi[1]]]
    cls = np.array(result_classid).astype(np.float64)
    conf = np.array(result_scores).astype(np.float64)
    xyxy = np.array(result_boxes).astype(np.float64)
    for i in range(len(cls)):
        ann = dict()
        if int(cls[i]) < 0 or int(cls[i]) >= len(categories):
            continue
        ann["category_name"] = categories[int(cls[i])].strip().replace(' ', '_').lower()
        ann["category_id"] = int(cls[i])
        ann["score"] = round(conf[i], 3)
        ann["bbox"] = [round(j, 3) for j in xyxy[i].tolist()]
        ann["bbox"][2] = ann["bbox"][2] - ann["bbox"][0]
        ann["bbox"][3] = ann["bbox"][3] - ann["bbox"][1]
        if roi is not None:
            ann["bbox"][0] += roi[0]
            ann["bbox"][1] += roi[1]
        sms_results["targets"].append(ann)

    return sms_results


class YOLOv11DetNode_Trt(threading.Thread, BaseNode):
    def __init__(
        self,
        job_name: str,
        ip: str = '127.0.0.1',
        port: int = 9094,
        param_dict_or_file: Union[dict, str] = None,
        sms_shutdown: bool = True,
        **kwargs
    ):
        threading.Thread.__init__(self)
        BaseNode.__init__(
            self,
            self.__class__.__name__,
            job_name,
            ip=ip,
            port=port,
            param_dict_or_file=param_dict_or_file,
            sms_shutdown=sms_shutdown,
            **kwargs
        )
        self.ctx = None
        self.launch_next_emit = self.get_param("launch_next_emit", True)
        self.specified_input_topic = self.get_param("specified_input_topic", "")
        self.specified_output_topic = self.get_param("specified_output_topic", "")
        self.realtime_det = self.get_param("realtime_det", False)
        self.remote_ip = self.get_param("remote_ip", "127.0.0.1")
        self.remote_port = self.get_param("remote_port", 9094)
        self.confidence = self.get_param("confidence", 0.001)
        self.nms_thresh = self.get_param("nms_thresh", 0.60)
        self.imgsz = self.get_param("imgsz", 640)
        self.dataset_name = self.get_param("dataset_name", "coco_detection")
        self.engine_path = self.get_param("engine_path", "yolo11s.engine")
        self.dataset_categories = self.get_param("/dataset_categories", {})
        self.use_shm = self.get_param("use_shm", -1)
        self.params_help()

        self.b_use_shm = False
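        # use_shm: 1 forces shared-memory transport, 0 disables it,
        # -1 (default) auto-enables it on Linux.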
        if self.use_shm == 1 or (self.use_shm == -1 and platform.system() == 'Linux'):
            self.b_use_shm = True

        if self.engine_path.startswith("sms::"):
            self.local_engine_path = download_model(self.__class__.__name__, self.engine_path)
            assert self.local_engine_path is not None
        else:
            self.local_engine_path = self.engine_path

        input_url = '/' + job_name + '/sensor/image_raw' \
            if len(self.specified_input_topic) == 0 else self.specified_input_topic

        output_url = '/' + job_name + '/detector/results' \
            if len(self.specified_output_topic) == 0 else self.specified_output_topic

        self.job_queue = Queue()
        self.queue_pool.append(self.job_queue)

        self._image_reader = Subscriber(
            input_url, 'std_msgs::Null', self.image_callback,
            ip=ip, port=port
        )
        self._result_writer = Publisher(
            output_url, 'spirecv_msgs::2DTargets',
            ip=self.remote_ip, port=self.remote_port, qos=QoS.Reliability
        )
        self._show_writer = Publisher(
            '/' + job_name + '/detector/image_results', 'memory_msgs::RawImage' if self.b_use_shm else 'sensor_msgs::CompressedImage',
            ip=ip, port=port
        )
        if self.launch_next_emit:
            self._next_writer = Publisher(
                '/' + job_name + '/launch_next', 'std_msgs::Boolean',
                ip=ip, port=port, qos=QoS.Reliability
            )

        self.dataset_categories = self.dataset_categories[self.dataset_name]
        self.input_h = self.imgsz
        self.input_w = self.imgsz
        self.batch_size = 1

        current_file_path = os.path.dirname(os.path.abspath(__file__))
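        # Load the custom-plugin library built under YOLOv11DetJetsonTensorRTLib
        # (see the build steps in the module docstring); deserializing the engine
        # fails without it, since the engine references the plugin's custom YOLO layers.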
        PLUGIN_LIBRARY = os.path.join(current_file_path, "YOLOv11DetJetsonTensorRTLib/build/libmyplugins.so")
        ctypes.CDLL(PLUGIN_LIBRARY)

        # Create a CUDA context on this device; it is pushed/popped explicitly around each inference.
        self.ctx = cuda.Device(0).make_context()
        stream = cuda.Stream()
        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
        runtime = trt.Runtime(TRT_LOGGER)
        
        # Deserialize the engine from file
        with open(self.local_engine_path, "rb") as f:
            engine = runtime.deserialize_cuda_engine(f.read())
        context = engine.create_execution_context()

        host_inputs = []
        cuda_inputs = []
        host_outputs = []
        cuda_outputs = []
        bindings = []
        
        for binding in engine:
            print('binding:', binding, engine.get_tensor_shape(binding))
            # Explicit-batch engines include the batch dimension in the tensor shape,
            # so the shape volume alone gives the buffer size.
            size = trt.volume(engine.get_tensor_shape(binding))
            dtype = np.float32  # trt.nptype(engine.get_tensor_dtype(binding))
            # Allocate host (page-locked) and device buffers
            host_mem = cuda.pagelocked_empty(size, dtype)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)
            # Append the device buffer address to the device bindings.
            bindings.append(int(cuda_mem))
            # Append to the appropriate list.
            if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
                self.batch_size = engine.get_tensor_shape(binding)[0]
                self.input_w = engine.get_tensor_shape(binding)[-1]
                self.input_h = engine.get_tensor_shape(binding)[-2]
                print('batch_size:', self.batch_size)
                host_inputs.append(host_mem)
                cuda_inputs.append(cuda_mem)
            else:
                host_outputs.append(host_mem)
                cuda_outputs.append(cuda_mem)

        # Store
        self.stream = stream
        self.context = context
        self.engine = engine
        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings
        self.det_output_length = host_outputs[0].shape[0]
        
        print(self.get_raw_image_zeros().shape)
        self.infer(self.get_raw_image_zeros())
        
        self.infer_n_ims = 0
        self.infer_time3 = 0.0
        self.infer_time2 = 0.0
        self.infer_time1 = 0.0
        
        self.start()

    def __del__(self):
        if self.ctx is not None:
            self.ctx.pop()

    def release(self):
        BaseNode.release(self)
        self._image_reader.kill()
        self._result_writer.kill()
        self._show_writer.kill()
        if self.launch_next_emit:
            self._next_writer.kill()

    def image_callback(self, msg):
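        # In realtime mode, drop incoming frames while a job is still queued, to keep
        # latency low (see the realtime_det parameter in the module docstring).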
        if self.realtime_det and not self.job_queue.empty():
            return

        img = sms2cvimg(msg)
        self.job_queue.put({'msg': msg, 'img': img})
    
    def get_raw_image_zeros(self):
        """
        description: Build an all-zero input image used to warm up the engine.
        """
        return np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)

    def preprocess_image(self, raw_bgr_image):
        """
        description: Convert a BGR image to RGB, letterbox-resize and pad it to the
                     network input size, normalize to [0,1], and transform it to
                     NCHW format.
        param:
            raw_bgr_image: np.ndarray, the input BGR image
        return:
            image:  the processed image
            image_raw: the original image
            h: original height
            w: original width
        """
        image_raw = raw_bgr_image
        h, w, c = image_raw.shape
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
        # Calculate the resized width/height and the paddings
        r_w = self.input_w / w
        r_h = self.input_h / h
        if r_h > r_w:
            tw = self.input_w
            th = int(r_w * h)
            tx1 = tx2 = 0
            ty1 = int((self.input_h - th) / 2)
            ty2 = self.input_h - th - ty1
        else:
            tw = int(r_h * w)
            th = self.input_h
            tx1 = int((self.input_w - tw) / 2)
            tx2 = self.input_w - tw - tx1
            ty1 = ty2 = 0
        # Resize the image with long side while maintaining ratio
        image = cv2.resize(image, (tw, th))
        # Pad the short side with (128,128,128)
        image = cv2.copyMakeBorder(
            image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (128, 128, 128)
        )
        image = image.astype(np.float32)
        # Normalize to [0,1]
        image /= 255.0
        # HWC to CHW format:
        image = np.transpose(image, [2, 0, 1])
        # CHW to NCHW format
        image = np.expand_dims(image, axis=0)
        # Convert the image to row-major order, also known as "C order":
        image = np.ascontiguousarray(image)
        return image, image_raw, h, w
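    # Letterbox example (assuming input_w = input_h = 640): a 1920x1080 frame is
    # scaled by r_w = 1/3 to 640x360, then padded with 140 gray rows on the top
    # and bottom to reach 640x640.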
    
    def post_process(self, output, origin_h, origin_w):
        """
        description: Postprocess the prediction.
        param:
            output:     a flat numpy array: [num_boxes, cx,cy,w,h,conf,cls_id,<seg/pose values>, ...]
            origin_h:   height of the original image
            origin_w:   width of the original image
        return:
            result_boxes: final boxes, a numpy array where each row is a box [x1, y1, x2, y2]
            result_scores: final scores, a numpy array with one score per box
            result_classid: final class ids, a numpy array with one class id per box
        """
        num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM
        # Get the number of boxes detected
        num = int(output[0])
        # Reshape to a two-dimensional ndarray: one row per detection
        pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
        # Do nms
        boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=self.confidence, nms_thres=self.nms_thresh)
        result_boxes = boxes[:, :4] if len(boxes) else np.array([])
        result_scores = boxes[:, 4] if len(boxes) else np.array([])
        result_classid = boxes[:, 5] if len(boxes) else np.array([])
        return result_boxes, result_scores, result_classid
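    # Output layout example: output[0] holds the detection count N; the next
    # N * (DET_NUM + SEG_NUM + POSE_NUM) = N * 89 floats hold one 89-value record
    # per detection, of which only the first DET_NUM values are used here.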
    
    def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
        """
        description: Removes detections with an object confidence score lower than 'conf_thres'
        and performs Non-Maximum Suppression to further filter detections.
        param:
            prediction: detections, each row starts with (cx, cy, w, h, conf, cls_id)
            origin_h: original image height
            origin_w: original image width
            conf_thres: a confidence threshold to filter detections
            nms_thres: an IoU threshold to filter detections
        return:
            boxes: output after NMS, each row starts with (x1, y1, x2, y2, conf, cls_id)
        """
        # Keep the boxes with score >= conf_thres
        boxes = prediction[prediction[:, 4] >= conf_thres]
        # Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
        boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
        # clip the coordinates
        boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
        boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
        boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)
        # Object confidence
        confs = boxes[:, 4]
        # Sort by the confs
        boxes = boxes[np.argsort(-confs)]
        # Perform non-maximum suppression
        keep_boxes = []
        while boxes.shape[0]:
            large_overlap = self.bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
            # Compare class ids (column 5 of the DET layout), not the last column of the
            # full per-detection record, so the suppression is truly class-aware.
            label_match = boxes[0, 5] == boxes[:, 5]
            # Indices of boxes with lower confidence scores, large IoUs and matching labels
            invalid = large_overlap & label_match
            keep_boxes += [boxes[0]]
            boxes = boxes[~invalid]
        boxes = np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([])
        return boxes
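    # Note: the label_match test above makes this a class-aware NMS; a box only
    # suppresses overlapping boxes of the same class.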
        
    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        """
        description: compute the IoU of two bounding boxes
        param:
            box1: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h))
            box2: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h))
            x1y1x2y2: select the coordinate format
        return:
            iou: computed iou
        """
        if not x1y1x2y2:
            # Transform from center and width to exact coordinates
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
        else:
            # Get the coordinates of bounding boxes
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

        # Get the coordinates of the intersection rectangle
        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
        inter_rect_y2 = np.minimum(b1_y2, b2_y2)
        # Intersection area
        inter_area = (np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None)
                      * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None))
        # Union Area
        b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
        b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)

        iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)

        return iou
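    # The +1 terms treat box coordinates as inclusive pixel indices, a convention
    # kept here for consistency with tensorrtx-style postprocessing.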
    
    def xywh2xyxy(self, origin_h, origin_w, x):
        """
        description:    Convert nx4 boxes from [center_x, center_y, w, h] in network-input
                        coordinates to [x1, y1, x2, y2] in original-image coordinates,
                        undoing the letterbox resize and padding (xy1 = top-left, xy2 = bottom-right).
        param:
            origin_h:   height of the original image
            origin_w:   width of the original image
            x:          a numpy array of boxes, each row is [center_x, center_y, w, h]
        return:
            y:          a numpy array of boxes, each row is [x1, y1, x2, y2]
        """
        y = np.zeros_like(x)
        r_w = self.input_w / origin_w
        r_h = self.input_h / origin_h
        if r_h > r_w:
            y[:, 0] = x[:, 0]
            y[:, 2] = x[:, 2]
            y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2
            y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2
            y /= r_w
        else:
            y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2
            y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2
            y[:, 1] = x[:, 1]
            y[:, 3] = x[:, 3]
            y /= r_h

        return y
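    # Worked example (640x640 input, 1920x1080 original): r_w = 1/3 < r_h, so a
    # network-space point x = 320, y = 320 maps to x = 320 / (1/3) = 960,
    # y = (320 - 140) / (1/3) = 540, i.e. the center of the original image.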
    
    def infer(self, raw_image):
        # Make self the active context, pushing it on top of the context stack.
        self.ctx.push()
        # Restore
        stream = self.stream
        context = self.context
        host_inputs = self.host_inputs
        cuda_inputs = self.cuda_inputs
        host_outputs = self.host_outputs
        cuda_outputs = self.cuda_outputs
        bindings = self.bindings
        # Do image preprocess
        batch_image_raw = []
        batch_origin_h = []
        batch_origin_w = []
        batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w])

        input_image, image_raw, origin_h, origin_w = self.preprocess_image(raw_image)
        batch_image_raw.append(image_raw)
        batch_origin_h.append(origin_h)
        batch_origin_w.append(origin_w)
        np.copyto(batch_input_image[0], input_image)
        batch_input_image = np.ascontiguousarray(batch_input_image)

        # Copy input image to host buffer
        np.copyto(host_inputs[0], batch_input_image.ravel())
        start = time.time()
        # Transfer input data to the GPU.
        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
        # Run inference.
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
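        # Note: execute_async_v2 follows the pre-TensorRT-10 binding API; on TensorRT 10+
        # it is removed in favor of set_tensor_address() plus execute_async_v3().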
        # Transfer predictions back from the GPU.
        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
        # Synchronize the stream
        stream.synchronize()
        end = time.time()
        # Remove any context from the top of the context stack, deactivating it.
        self.ctx.pop()
        # Since batch_size = 1, use the first (and only) output in the batch
        output = host_outputs[0]
        # Do postprocess

        result_boxes, result_scores, result_classid = self.post_process(
            output[0: self.det_output_length], batch_origin_h[0],
            batch_origin_w[0]
        )
        return result_boxes, result_scores, result_classid

    def run(self):
        while self.is_running():
            msg_dict = self.job_queue.get(block=True)
            if msg_dict is None:
                break

            t1 = time.time()
            msg, img = msg_dict['msg'], msg_dict['img']
            file_name = msg['file_name'] if 'file_name' in msg else ''

            self.infer_n_ims += 1
            self.infer_time1 += time.time() - t1
            print('avg-infer time1: {}'.format(self.infer_time1 / self.infer_n_ims))

            t1 = time.time()

            if "rois" in msg and len(msg["rois"]) > 0:
                roi = msg["rois"][0]
                img_infer = img[roi[1]: roi[3], roi[0]: roi[2], :]
            else:
                roi = None
                img_infer = img
            # DO Object Detection
            result_boxes, result_scores, result_classid = self.infer(img_infer)

            res_msg = trans_det_results(result_boxes, result_scores, result_classid, img.shape[1], img.shape[0], self.dataset_categories, roi)
            res_msg['file_name'] = file_name
            res_msg['dataset'] = self.dataset_name
            if 'client_id' in msg:
                res_msg['client_id'] = msg['client_id']
            if 'img_id' in msg:
                res_msg['img_id'] = msg['img_id']
            if 'img_total' in msg:
                res_msg['img_total'] = msg['img_total']
            res_msg['time_used'] = time.time() - t1

            self.infer_time2 += time.time() - t1
            print('avg-infer time2: {}'.format(self.infer_time2 / self.infer_n_ims))

            t1 = time.time()

            self._result_writer.publish(res_msg)

            if 'img_total' in msg and self.launch_next_emit:
                next_msg = def_msg('std_msgs::Boolean')
                next_msg['data'] = True
                self._next_writer.publish(next_msg)

            if self.b_use_shm:
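            # Linux fast path: republish the raw frame via shared memory instead of re-encoding it.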
                msg = self._show_writer.cvimg2sms_mem(img)
            msg['spirecv_msgs::2DTargets'] = res_msg
            self._show_writer.publish(msg)

            self.infer_time3 += time.time() - t1
            print('avg-infer time3: {}, totally {} imgs'.format(self.infer_time3 / self.infer_n_ims, self.infer_n_ims))
            # END

        self.release()
        print('{} quit!'.format(self.__class__.__name__))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        type=str,
        default='default_params.json',
        help='SpireCV2 Config (.json)')
    parser.add_argument(
        '--job-name', '-j',
        type=str,
        default='live',
        help='SpireCV Job Name')
    parser.add_argument(
        '--ip',
        type=str,
        default='127.0.0.1',
        help='SpireMS Core IP')
    parser.add_argument(
        '--port',
        type=int,
        default=9094,
        help='SpireMS Core Port')
    args, unknown_args = parser.parse_known_args()
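    # Resolve a relative config path against <repo>/params/spirecv2
    # (the '+ 11' keeps the path through the 'spirecv-pro' directory name).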
    if not os.path.isabs(args.config):
        current_path = os.path.abspath(__file__)
        params_dir = os.path.join(current_path[:current_path.find('spirecv-pro') + 11], 'params', 'spirecv2')
        args.config = os.path.join(params_dir, args.config)
    print("--config:", args.config)
    print("--job-name:", args.job_name)
    extra = get_extra_args(unknown_args)
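    # get_extra_args presumably parses the remaining `--key value` pairs into a dict
    # of parameter overrides, which is forwarded to the node as keyword arguments.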

    node = YOLOv11DetNode_Trt(args.job_name, param_dict_or_file=args.config, ip=args.ip, port=args.port, **extra)
    node.join()
